import math
from typing import List, Optional

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


class MaxWindowPPLDetector:
    """
    Batch-safe max-window perplexity (PPL) detector.

    A text is scored by the highest perplexity observed over any window of
    ``K`` consecutive predicted tokens under a causal language model. Texts
    with fewer than ``K`` predicted tokens are scored by the perplexity of
    the whole sequence.
    """

    def __init__(
        self,
        model_name: str,
        K: int = 10,
        device: Optional[str] = None,
    ):
        """
        Load the tokenizer and model and prepare them for batched scoring.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the causal LM.
            K: Window size, in tokens, used by ``max_window_ppl``.
            device: Torch device string. Defaults to ``"cuda"`` when
                available, otherwise ``"cpu"``.

        Raises:
            ValueError: If ``K`` is smaller than 1.
        """
        # Validate before the (expensive) model load.
        if K < 1:
            raise ValueError(f"K must be a positive integer, got {K}")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.K = K

        # Load tokenizer + model
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=True, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, torch_dtype=torch.float16
        )

        if self.tokenizer.pad_token_id is None:
            try:
                print("Adding pad token to tokenizer...")
                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
                self.model.resize_token_embeddings(len(self.tokenizer))
            except ValueError:
                print("Tokenizer does not support adding special tokens.")
                # Qwen and similar tokenizers may not allow adding; fallback
                self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})

        # The scoring methods slice each row as ``loss[b, :length]``, which is
        # only valid when padding sits at the end of the row. Enforce this for
        # every tokenizer, not just those that lacked a pad token.
        self.tokenizer.padding_side = "right"

        # Move model to device
        self.model.to(self.device).eval()

    def _compute_nll_batch(self, texts: List[str]):
        """
        Compute per-token negative log-likelihoods for a batch of texts.

        Args:
            texts: Non-empty list of strings to score.

        Returns:
            A ``(loss, lengths)`` pair. ``loss`` is a float32 tensor of shape
            ``(B, T-1)`` holding the NLL of each token given its prefix, with
            0 at padded positions. ``lengths`` is a list of ``B`` ints giving
            the number of real (non-padding) predicted tokens per row; with
            right padding these occupy ``loss[b, :lengths[b]]``.
        """
        # Tokenize with padding to longest in batch
        enc = self.tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            add_special_tokens=True,
        ).to(self.device)

        input_ids = enc.input_ids           # (B, T)
        attention_mask = enc.attention_mask # (B, T)

        with torch.no_grad():
            outputs = self.model(**enc)

        # Shift so that tokens 1..T are predicted by logits at 0..T-1, and
        # cast to float32: the model runs in float16, and a float16 loss
        # overflows to inf as soon as exp(mean NLL) exceeds ~65504.
        shift_logits = outputs.logits[:, :-1, :].float()   # (B, T-1, V)
        shift_labels = input_ids[:, 1:]                     # (B, T-1)

        # Mask padding via the attention mask rather than by comparing with
        # pad_token_id: the pad token may coincide with a real token (e.g.
        # the '<|endoftext|>' fallback), and real occurrences must be scored.
        mask = attention_mask[:, 1:].bool()                 # (B, T-1)
        shift_labels = shift_labels.masked_fill(~mask, -100)  # cross_entropy ignore

        # Flatten for cross_entropy
        B, Tm1, V = shift_logits.size()
        flat_loss = F.cross_entropy(
            shift_logits.reshape(-1, V),
            shift_labels.reshape(-1),
            reduction='none',
            ignore_index=-100,
        )                                                   # (B*(T-1),)
        loss = flat_loss.view(B, Tm1)                       # (B, T-1)

        # True lengths per example (number of non-pad labels)
        lengths = mask.sum(dim=1).tolist()                  # list of B ints

        return loss, lengths

    def max_window_ppl(self, nll_vec: torch.Tensor) -> float:
        """
        Return the maximum perplexity over all length-``K`` windows.

        Args:
            nll_vec: 1-D tensor of per-token negative log-likelihoods.

        Returns:
            ``exp`` of the largest mean NLL over any ``K`` consecutive
            entries. If there are fewer than ``K`` entries, ``exp`` of the
            mean over all of them. NaN if ``nll_vec`` is empty, since there
            is nothing to score.
        """
        # Work in float32 so half-precision input cannot overflow in exp().
        nll = nll_vec.float()
        n_tokens = nll.size(0)
        if n_tokens == 0:
            return float("nan")
        if n_tokens < self.K:
            return torch.exp(nll.mean()).item()

        # All sliding windows as a strided (n_tokens-K+1, K) view. exp() is
        # monotonic, so the max PPL is exp() of the max window mean; this
        # needs a single device sync instead of one per window.
        window_means = nll.unfold(0, self.K, 1).mean(dim=1)
        return torch.exp(window_means.max()).item()

    def score(self, text: str) -> float:
        """Return the max-window PPL of a single text."""
        nll_batch, lengths = self._compute_nll_batch([text])
        return self.max_window_ppl(nll_batch[0, : lengths[0]])

    def score_batch(
        self,
        texts: List[str],
        batch_size: int = 8,
    ) -> List[float]:
        """
        Return the max-window PPL of each text, in input order.

        Args:
            texts: Texts to score; an empty list yields an empty result.
            batch_size: Number of texts per forward pass.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        results: List[float] = []
        for start in range(0, len(texts), batch_size):
            chunk = texts[start : start + batch_size]
            nll_batch, lengths = self._compute_nll_batch(chunk)
            # Drop the padded tail of each row before windowing.
            results.extend(
                self.max_window_ppl(nll_batch[row, :length])
                for row, length in enumerate(lengths)
            )
        return results



